{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RMSProp\n",
    "\n",
    "In the experiment in the Adagrad section, the learning rate of each element in the independent variable\n",
    "of the objective function declines (or remains unchanged) during iteration because the variable $s_t$ in\n",
    "the denominator is increased by the square by element operation of the mini-batch stochastic gradient,\n",
    "adjusting the learning rate. Therefore, when the learning rate declines very fast during early iteration, yet\n",
    "the current solution is still not desirable, Adagrad might have difficulty finding a useful solution because\n",
    "the learning rate will be too small at later stages of iteration. To tackle this problem, the RMSProp\n",
    "algorithm made a small modification to Adagrad.\n",
    "\n",
    "## 8.6.1 The Algorithm\n",
    "\n",
    "We introduced EWMA (exponentially weighted moving average) in the Momentum section. Unlike in\n",
    "Adagrad, the state variable $s_t$ is the sum of the square by element all the mini-batch stochastic gradients\n",
    "$g_t$ up to the time step t, RMSProp uses the EWMA on the square by element results of these gradients.\n",
    "Specifically, given the hyperparameter 0 ≤ $ \\gamma $ < 1, RMSProp is computed at time step t > 0.\n",
    "\n",
    "$$ \\begin{aligned} \\mathbf{s}_t \\leftarrow \\gamma \\mathbf{s}_{t-1} + (1 - \\gamma) \\mathbf{g}_t * \\mathbf{g}_t \\end{aligned} $$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like Adagrad, RMSProp re-adjusts the learning rate of each element in the independent variable of the\n",
    "objective function with element operations and then updates the independent variable.\n",
    "\n",
    "$$ \\begin{aligned} \\mathbf{x}_t \\leftarrow  \\mathbf{x}_{t-1} (\\frac{\\eta}{\\sqrt{\\mathbf{s}_t + \\epsilon}}) * \\mathbf{g}_t \\end{aligned} $$ \n",
    "\n",
    "Here, η is the learning rate while ε is a constant added to maintain numerical stability, such as $10 ^ {−6}$ .\n",
    "Because the state variable of RMSProp is an EWMA of the squared term $g_t * g_t$ , it can be seen as the\n",
    "weighted average of the mini-batch stochastic gradient’s squared terms from the last 1/(1 − $ \\gamma $) time steps.\n",
    "Therefore, the learning rate of each element in the independent variable will not always decline (or remain\n",
    "unchanged) during iteration.\n",
    "\n",
    "By convention, we will use the objective function f (x) = 0.1x 21 + 2x 22 to observe the iterative trajectory\n",
    "of the independent variable in RMSProp. Recall that in the Adagrad section, when we used Adagrad with\n",
    "a learning rate of 0.4, the independent variable moved less in later stages of iteration. However, at the\n",
    "same learning rate, RMSProp can approach the optimal solution faster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, x1 -0.010599, x2 0.000000\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"184.455469pt\" version=\"1.1\" viewBox=\"0 0 245.120313 184.455469\" width=\"245.120313pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;white-space:pre;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 184.455469 \n",
       "L 245.120313 184.455469 \n",
       "L 245.120313 0 \n",
       "L 0 0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 42.620312 146.899219 \n",
       "L 237.920313 146.899219 \n",
       "L 237.920313 10.999219 \n",
       "L 42.620312 10.999219 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"m6ad9f9a0d8\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"88.39375\" xlink:href=\"#m6ad9f9a0d8\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- −4 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.59375 35.5 \n",
       "L 73.1875 35.5 \n",
       "L 73.1875 27.203125 \n",
       "L 10.59375 27.203125 \n",
       "z\n",
       "\" id=\"DejaVuSans-8722\"/>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-52\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(81.022656 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-52\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"149.425\" xlink:href=\"#m6ad9f9a0d8\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- −2 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(142.053906 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"210.45625\" xlink:href=\"#m6ad9f9a0d8\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(207.275 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_4\">\n",
       "     <!-- x1 -->\n",
       "     <defs>\n",
       "      <path d=\"M 54.890625 54.6875 \n",
       "L 35.109375 28.078125 \n",
       "L 55.90625 0 \n",
       "L 45.3125 0 \n",
       "L 29.390625 21.484375 \n",
       "L 13.484375 0 \n",
       "L 2.875 0 \n",
       "L 24.125 28.609375 \n",
       "L 4.6875 54.6875 \n",
       "L 15.28125 54.6875 \n",
       "L 29.78125 35.203125 \n",
       "L 44.28125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-120\"/>\n",
       "      <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(134.129687 175.175781)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-120\"/>\n",
       "      <use x=\"59.179688\" xlink:href=\"#DejaVuSans-49\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m09fedc6792\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m09fedc6792\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- −3 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-51\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 150.698437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-51\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m09fedc6792\" y=\"112.924219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_6\">\n",
       "      <!-- −2 -->\n",
       "      <g transform=\"translate(20.878125 116.723437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m09fedc6792\" y=\"78.949219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- −1 -->\n",
       "      <g transform=\"translate(20.878125 82.748437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-49\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m09fedc6792\" y=\"44.974219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(29.257812 48.773437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m09fedc6792\" y=\"10.999219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(29.257812 14.798437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- x2 -->\n",
       "     <g transform=\"translate(14.798437 85.089844)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-120\"/>\n",
       "      <use x=\"59.179688\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_1\"/>\n",
       "   <g id=\"LineCollection_2\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 237.920313 86.009224 \n",
       "L 234.86875 86.124739 \n",
       "L 231.817188 86.226664 \n",
       "L 228.765625 86.314999 \n",
       "L 225.714063 86.389744 \n",
       "L 222.6625 86.450899 \n",
       "L 219.610938 86.498464 \n",
       "L 216.559375 86.532439 \n",
       "L 213.507813 86.552824 \n",
       "L 210.45625 86.559619 \n",
       "L 207.404688 86.552824 \n",
       "L 204.353125 86.532439 \n",
       "L 201.301563 86.498464 \n",
       "L 198.25 86.450899 \n",
       "L 195.198438 86.389744 \n",
       "L 192.146875 86.314999 \n",
       "L 189.095313 86.226664 \n",
       "L 186.04375 86.124739 \n",
       "L 182.992188 86.009224 \n",
       "L 179.940625 85.880119 \n",
       "L 177.034375 85.744219 \n",
       "L 176.889063 85.736833 \n",
       "L 173.8375 85.566958 \n",
       "L 170.785938 85.382311 \n",
       "L 167.734375 85.182893 \n",
       "L 164.682813 84.968702 \n",
       "L 161.63125 84.73974 \n",
       "L 158.579688 84.496007 \n",
       "L 155.528125 84.237501 \n",
       "L 152.476563 83.964224 \n",
       "L 149.425 83.676175 \n",
       "L 146.373438 83.373355 \n",
       "L 143.321875 83.055762 \n",
       "L 140.270313 82.723398 \n",
       "L 137.21875 82.376262 \n",
       "L 136.969643 82.346719 \n",
       "L 134.167188 81.982701 \n",
       "L 131.115625 81.570147 \n",
       "L 128.064063 81.141415 \n",
       "L 125.0125 80.696504 \n",
       "L 121.960938 80.235415 \n",
       "L 118.909375 79.758147 \n",
       "L 115.857813 79.264701 \n",
       "L 113.96875 78.949219 \n",
       "L 112.80625 78.73464 \n",
       "L 109.754688 78.153488 \n",
       "L 106.703125 77.554456 \n",
       "L 103.651563 76.937541 \n",
       "L 100.6 76.302745 \n",
       "L 97.548438 75.650067 \n",
       "L 97.100875 75.551719 \n",
       "L 94.496875 74.912189 \n",
       "L 91.445313 74.142756 \n",
       "L 88.39375 73.353336 \n",
       "L 85.342188 72.543932 \n",
       "L 83.908321 72.154219 \n",
       "L 82.290625 71.655919 \n",
       "L 79.239063 70.693294 \n",
       "L 76.1875 69.708019 \n",
       "L 73.307374 68.756719 \n",
       "L 73.135938 68.691382 \n",
       "L 70.084375 67.502257 \n",
       "L 67.032813 66.286998 \n",
       "L 64.752171 65.359219 \n",
       "L 63.98125 64.988582 \n",
       "L 60.929688 63.490594 \n",
       "L 57.878125 61.961719 \n",
       "L 57.878125 61.961719 \n",
       "L 54.826563 60.055344 \n",
       "L 52.486044 58.564219 \n",
       "L 51.775 57.98179 \n",
       "L 48.723438 55.433665 \n",
       "L 48.409725 55.166719 \n",
       "L 45.671875 51.905119 \n",
       "L 45.559891 51.769219 \n",
       "L 43.880132 48.371719 \n",
       "L 43.320212 44.974219 \n",
       "L 43.880132 41.576719 \n",
       "L 45.559891 38.179219 \n",
       "L 45.671875 38.043319 \n",
       "L 48.409725 34.781719 \n",
       "L 48.723438 34.514772 \n",
       "L 51.775 31.966647 \n",
       "L 52.486044 31.384219 \n",
       "L 54.826563 29.893094 \n",
       "L 57.878125 27.986719 \n",
       "L 57.878125 27.986719 \n",
       "L 60.929688 26.457844 \n",
       "L 63.98125 24.959855 \n",
       "L 64.752171 24.589219 \n",
       "L 67.032813 23.66144 \n",
       "L 70.084375 22.44618 \n",
       "L 73.135938 21.257055 \n",
       "L 73.307374 21.191719 \n",
       "L 76.1875 20.240419 \n",
       "L 79.239063 19.255144 \n",
       "L 82.290625 18.292519 \n",
       "L 83.908321 17.794219 \n",
       "L 85.342188 17.404506 \n",
       "L 88.39375 16.595101 \n",
       "L 91.445313 15.805682 \n",
       "L 94.496875 15.036248 \n",
       "L 97.100875 14.396719 \n",
       "L 97.548438 14.29837 \n",
       "L 100.6 13.645692 \n",
       "L 103.651563 13.010896 \n",
       "L 106.703125 12.393982 \n",
       "L 109.754688 11.794949 \n",
       "L 112.80625 11.213798 \n",
       "L 113.96875 10.999219 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_3\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 237.920313 103.406365 \n",
       "L 234.86875 103.488876 \n",
       "L 231.817188 103.561679 \n",
       "L 228.765625 103.624776 \n",
       "L 225.714063 103.678165 \n",
       "L 222.6625 103.721847 \n",
       "L 219.610938 103.755822 \n",
       "L 216.559375 103.78009 \n",
       "L 213.507813 103.794651 \n",
       "L 210.45625 103.799504 \n",
       "L 207.404688 103.794651 \n",
       "L 204.353125 103.78009 \n",
       "L 201.301563 103.755822 \n",
       "L 198.25 103.721847 \n",
       "L 195.198438 103.678165 \n",
       "L 192.146875 103.624776 \n",
       "L 189.095313 103.561679 \n",
       "L 186.04375 103.488876 \n",
       "L 182.992188 103.406365 \n",
       "L 179.940625 103.314147 \n",
       "L 176.889063 103.212222 \n",
       "L 173.8375 103.10059 \n",
       "L 170.785938 102.979251 \n",
       "L 167.734375 102.848204 \n",
       "L 165.208944 102.731719 \n",
       "L 164.682813 102.70598 \n",
       "L 161.63125 102.546401 \n",
       "L 158.579688 102.376526 \n",
       "L 155.528125 102.196355 \n",
       "L 152.476563 102.005889 \n",
       "L 149.425 101.805128 \n",
       "L 146.373438 101.594071 \n",
       "L 143.321875 101.372719 \n",
       "L 140.270313 101.141071 \n",
       "L 137.21875 100.899128 \n",
       "L 134.167188 100.646889 \n",
       "L 131.115625 100.384355 \n",
       "L 128.064063 100.111526 \n",
       "L 125.0125 99.828401 \n",
       "L 121.960938 99.53498 \n",
       "L 119.943803 99.334219 \n",
       "L 118.909375 99.224622 \n",
       "L 115.857813 98.890352 \n",
       "L 112.80625 98.545122 \n",
       "L 109.754688 98.188932 \n",
       "L 106.703125 97.821783 \n",
       "L 103.651563 97.443674 \n",
       "L 100.6 97.054606 \n",
       "L 97.548438 96.654578 \n",
       "L 94.496875 96.24359 \n",
       "L 92.277557 95.936719 \n",
       "L 91.445313 95.813706 \n",
       "L 88.39375 95.350943 \n",
       "L 85.342188 94.876464 \n",
       "L 82.290625 94.39027 \n",
       "L 79.239063 93.892361 \n",
       "L 76.1875 93.382736 \n",
       "L 73.135938 92.861395 \n",
       "L 71.291587 92.539219 \n",
       "L 70.084375 92.312719 \n",
       "L 67.032813 91.727594 \n",
       "L 63.98125 91.129885 \n",
       "L 60.929688 90.519594 \n",
       "L 57.878125 89.896719 \n",
       "L 54.826563 89.26126 \n",
       "L 54.263653 89.141719 \n",
       "L 51.775 88.570939 \n",
       "L 48.723438 87.857464 \n",
       "L 45.671875 87.130399 \n",
       "L 42.620313 86.389744 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_4\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 237.920313 116.712826 \n",
       "L 234.86875 116.779986 \n",
       "L 231.817188 116.839245 \n",
       "L 228.765625 116.890602 \n",
       "L 225.714063 116.934059 \n",
       "L 222.6625 116.969614 \n",
       "L 219.610938 116.997268 \n",
       "L 216.559375 117.017021 \n",
       "L 213.507813 117.028873 \n",
       "L 210.45625 117.032823 \n",
       "L 207.404688 117.028873 \n",
       "L 204.353125 117.017021 \n",
       "L 201.301563 116.997268 \n",
       "L 198.25 116.969614 \n",
       "L 195.198438 116.934059 \n",
       "L 192.146875 116.890602 \n",
       "L 189.095313 116.839245 \n",
       "L 186.04375 116.779986 \n",
       "L 182.992188 116.712826 \n",
       "L 179.940625 116.637765 \n",
       "L 176.889063 116.554803 \n",
       "L 173.8375 116.46394 \n",
       "L 170.785938 116.365175 \n",
       "L 169.542708 116.321719 \n",
       "L 167.734375 116.255426 \n",
       "L 164.682813 116.135271 \n",
       "L 161.63125 116.006829 \n",
       "L 158.579688 115.8701 \n",
       "L 155.528125 115.725085 \n",
       "L 152.476563 115.571783 \n",
       "L 149.425 115.410194 \n",
       "L 146.373438 115.240319 \n",
       "L 143.321875 115.062158 \n",
       "L 140.270313 114.87571 \n",
       "L 137.21875 114.680975 \n",
       "L 134.167188 114.477954 \n",
       "L 131.115625 114.266646 \n",
       "L 128.064063 114.047051 \n",
       "L 125.0125 113.81917 \n",
       "L 121.960938 113.583002 \n",
       "L 118.909375 113.338548 \n",
       "L 115.857813 113.085807 \n",
       "L 113.96875 112.924219 \n",
       "L 112.80625 112.81968 \n",
       "L 109.754688 112.536555 \n",
       "L 106.703125 112.244719 \n",
       "L 103.651563 111.944171 \n",
       "L 100.6 111.634911 \n",
       "L 97.548438 111.31694 \n",
       "L 94.496875 110.990257 \n",
       "L 91.445313 110.654863 \n",
       "L 88.39375 110.310757 \n",
       "L 85.342188 109.95794 \n",
       "L 82.290625 109.596411 \n",
       "L 81.716213 109.526719 \n",
       "L 79.239063 109.209925 \n",
       "L 76.1875 108.810489 \n",
       "L 73.135938 108.401871 \n",
       "L 70.084375 107.98407 \n",
       "L 67.032813 107.557087 \n",
       "L 63.98125 107.120921 \n",
       "L 60.929688 106.675573 \n",
       "L 57.878125 106.221043 \n",
       "L 57.273855 106.129219 \n",
       "L 54.826563 105.736079 \n",
       "L 51.775 105.236162 \n",
       "L 48.723438 104.726537 \n",
       "L 45.671875 104.207204 \n",
       "L 42.620313 103.678165 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_5\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 237.920313 127.897487 \n",
       "L 234.86875 127.956423 \n",
       "L 231.817188 128.008425 \n",
       "L 228.765625 128.053494 \n",
       "L 225.714063 128.091629 \n",
       "L 222.6625 128.122831 \n",
       "L 219.610938 128.147099 \n",
       "L 216.559375 128.164433 \n",
       "L 213.507813 128.174834 \n",
       "L 210.45625 128.1783 \n",
       "L 207.404688 128.174834 \n",
       "L 204.353125 128.164433 \n",
       "L 201.301563 128.147099 \n",
       "L 198.25 128.122831 \n",
       "L 195.198438 128.091629 \n",
       "L 192.146875 128.053494 \n",
       "L 189.095313 128.008425 \n",
       "L 186.04375 127.956423 \n",
       "L 182.992188 127.897487 \n",
       "L 179.940625 127.831617 \n",
       "L 176.889063 127.758813 \n",
       "L 173.8375 127.679076 \n",
       "L 170.785938 127.592405 \n",
       "L 167.734375 127.4988 \n",
       "L 164.682813 127.398262 \n",
       "L 161.63125 127.29079 \n",
       "L 158.579688 127.176385 \n",
       "L 155.528125 127.055045 \n",
       "L 152.476563 126.926772 \n",
       "L 149.425 126.791566 \n",
       "L 146.373438 126.649425 \n",
       "L 143.605741 126.514219 \n",
       "L 143.321875 126.499761 \n",
       "L 140.270313 126.337115 \n",
       "L 137.21875 126.16724 \n",
       "L 134.167188 125.990136 \n",
       "L 131.115625 125.805804 \n",
       "L 128.064063 125.614243 \n",
       "L 125.0125 125.415453 \n",
       "L 121.960938 125.209434 \n",
       "L 118.909375 124.996187 \n",
       "L 115.857813 124.775711 \n",
       "L 112.80625 124.548006 \n",
       "L 109.754688 124.313072 \n",
       "L 106.703125 124.07091 \n",
       "L 103.651563 123.821519 \n",
       "L 100.6 123.5649 \n",
       "L 97.548438 123.301051 \n",
       "L 95.473375 123.116719 \n",
       "L 94.496875 123.026119 \n",
       "L 91.445313 122.735444 \n",
       "L 88.39375 122.437219 \n",
       "L 85.342188 122.131444 \n",
       "L 82.290625 121.818119 \n",
       "L 79.239063 121.497244 \n",
       "L 76.1875 121.168819 \n",
       "L 73.135938 120.832844 \n",
       "L 70.084375 120.489319 \n",
       "L 67.032813 120.138244 \n",
       "L 63.98125 119.779619 \n",
       "L 63.477899 119.719219 \n",
       "L 60.929688 119.399222 \n",
       "L 57.878125 119.008114 \n",
       "L 54.826563 118.609105 \n",
       "L 51.775 118.202195 \n",
       "L 48.723438 117.787384 \n",
       "L 45.671875 117.364672 \n",
       "L 42.620313 116.934059 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_6\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 237.920313 137.753766 \n",
       "L 234.86875 137.806273 \n",
       "L 231.817188 137.852603 \n",
       "L 228.765625 137.892755 \n",
       "L 225.714063 137.92673 \n",
       "L 222.6625 137.954528 \n",
       "L 219.610938 137.976148 \n",
       "L 216.559375 137.991591 \n",
       "L 213.507813 138.000857 \n",
       "L 210.45625 138.003946 \n",
       "L 207.404688 138.000857 \n",
       "L 204.353125 137.991591 \n",
       "L 201.301563 137.976148 \n",
       "L 198.25 137.954528 \n",
       "L 195.198438 137.92673 \n",
       "L 192.146875 137.892755 \n",
       "L 189.095313 137.852603 \n",
       "L 186.04375 137.806273 \n",
       "L 182.992188 137.753766 \n",
       "L 179.940625 137.695082 \n",
       "L 176.889063 137.630221 \n",
       "L 173.8375 137.559182 \n",
       "L 170.785938 137.481966 \n",
       "L 167.734375 137.398573 \n",
       "L 164.682813 137.309003 \n",
       "L 161.63125 137.213255 \n",
       "L 158.579688 137.11133 \n",
       "L 155.528125 137.003228 \n",
       "L 152.476563 136.888948 \n",
       "L 149.425 136.768491 \n",
       "L 147.936433 136.706719 \n",
       "L 146.373438 136.63941 \n",
       "L 143.321875 136.501587 \n",
       "L 140.270313 136.357353 \n",
       "L 137.21875 136.206709 \n",
       "L 134.167188 136.049655 \n",
       "L 131.115625 135.88619 \n",
       "L 128.064063 135.716315 \n",
       "L 125.0125 135.54003 \n",
       "L 121.960938 135.357334 \n",
       "L 118.909375 135.168228 \n",
       "L 115.857813 134.972712 \n",
       "L 112.80625 134.770785 \n",
       "L 109.754688 134.562448 \n",
       "L 106.703125 134.3477 \n",
       "L 103.651563 134.126542 \n",
       "L 100.6 133.898973 \n",
       "L 97.548438 133.664995 \n",
       "L 94.496875 133.424606 \n",
       "L 93.07017 133.309219 \n",
       "L 91.445313 133.172653 \n",
       "L 88.39375 132.909513 \n",
       "L 85.342188 132.639711 \n",
       "L 82.290625 132.363248 \n",
       "L 79.239063 132.080123 \n",
       "L 76.1875 131.790336 \n",
       "L 73.135938 131.493888 \n",
       "L 70.084375 131.190778 \n",
       "L 67.032813 130.881006 \n",
       "L 63.98125 130.564572 \n",
       "L 60.929688 130.241476 \n",
       "L 57.878125 129.911719 \n",
       "L 57.878125 129.911719 \n",
       "L 54.826563 129.561568 \n",
       "L 51.775 129.204484 \n",
       "L 48.723438 128.840466 \n",
       "L 45.671875 128.469515 \n",
       "L 42.620313 128.091629 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_7\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 210.45625 146.899219 \n",
       "L 207.404688 146.89634 \n",
       "L 204.353125 146.887702 \n",
       "L 201.301563 146.873306 \n",
       "L 198.25 146.853151 \n",
       "L 195.198438 146.827238 \n",
       "L 192.146875 146.795566 \n",
       "L 189.095313 146.758136 \n",
       "L 186.04375 146.714948 \n",
       "L 182.992188 146.666001 \n",
       "L 179.940625 146.611295 \n",
       "L 176.889063 146.550831 \n",
       "L 173.8375 146.484609 \n",
       "L 170.785938 146.412628 \n",
       "L 167.734375 146.334888 \n",
       "L 164.682813 146.25139 \n",
       "L 161.63125 146.162134 \n",
       "L 158.579688 146.067119 \n",
       "L 155.528125 145.966346 \n",
       "L 152.476563 145.859814 \n",
       "L 149.425 145.747524 \n",
       "L 146.373438 145.629475 \n",
       "L 143.321875 145.505668 \n",
       "L 140.270313 145.376102 \n",
       "L 137.21875 145.240778 \n",
       "L 134.167188 145.099695 \n",
       "L 131.115625 144.952854 \n",
       "L 128.064063 144.800255 \n",
       "L 125.0125 144.641897 \n",
       "L 121.960938 144.47778 \n",
       "L 118.909375 144.307905 \n",
       "L 115.857813 144.132272 \n",
       "L 112.80625 143.95088 \n",
       "L 109.754688 143.763729 \n",
       "L 106.703125 143.57082 \n",
       "L 105.641712 143.501719 \n",
       "L 103.651563 143.367607 \n",
       "L 100.6 143.156008 \n",
       "L 97.548438 142.938449 \n",
       "L 94.496875 142.714929 \n",
       "L 91.445313 142.485449 \n",
       "L 88.39375 142.250008 \n",
       "L 85.342188 142.008607 \n",
       "L 82.290625 141.761245 \n",
       "L 79.239063 141.507923 \n",
       "L 76.1875 141.24864 \n",
       "L 73.135938 140.983396 \n",
       "L 70.084375 140.712192 \n",
       "L 67.032813 140.435028 \n",
       "L 63.98125 140.151903 \n",
       "L 63.477899 140.104219 \n",
       "L 60.929688 139.854039 \n",
       "L 57.878125 139.548264 \n",
       "L 54.826563 139.236312 \n",
       "L 51.775 138.918182 \n",
       "L 48.723438 138.593876 \n",
       "L 45.671875 138.263391 \n",
       "L 42.620313 137.92673 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 237.920313 146.666001 \n",
       "L 234.86875 146.714948 \n",
       "L 231.817188 146.758136 \n",
       "L 228.765625 146.795566 \n",
       "L 225.714063 146.827238 \n",
       "L 222.6625 146.853151 \n",
       "L 219.610938 146.873306 \n",
       "L 216.559375 146.887702 \n",
       "L 213.507813 146.89634 \n",
       "L 210.45625 146.899219 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_8\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 43.320212 146.899219 \n",
       "L 42.620313 146.827238 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_9\"/>\n",
       "   <g id=\"line2d_9\">\n",
       "    <path clip-path=\"url(#p8674f2c5ad)\" d=\"M 57.878125 112.924219 \n",
       "L 96.477484 69.948869 \n",
       "L 120.357142 54.42355 \n",
       "L 138.043261 48.294977 \n",
       "L 152.011133 46.027885 \n",
       "L 163.368487 45.269751 \n",
       "L 172.726061 45.04561 \n",
       "L 180.468975 44.988536 \n",
       "L 186.866078 44.976471 \n",
       "L 192.121325 44.97447 \n",
       "L 196.40025 44.974235 \n",
       "L 199.844228 44.974219 \n",
       "L 202.578131 44.974219 \n",
       "L 204.714223 44.974219 \n",
       "L 206.353865 44.974219 \n",
       "L 207.588039 44.974219 \n",
       "L 208.497343 44.974219 \n",
       "L 209.15187 44.974219 \n",
       "L 209.611243 44.974219 \n",
       "L 209.924922 44.974219 \n",
       "L 210.132825 44.974219 \n",
       "\" style=\"fill:none;stroke:#ff7f0e;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "    <defs>\n",
       "     <path d=\"M 0 3 \n",
       "C 0.795609 3 1.55874 2.683901 2.12132 2.12132 \n",
       "C 2.683901 1.55874 3 0.795609 3 0 \n",
       "C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132 \n",
       "C 1.55874 -2.683901 0.795609 -3 0 -3 \n",
       "C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132 \n",
       "C -2.683901 -1.55874 -3 -0.795609 -3 0 \n",
       "C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132 \n",
       "C -1.55874 2.683901 -0.795609 3 0 3 \n",
       "z\n",
       "\" id=\"m7dd7e52eab\" style=\"stroke:#ff7f0e;\"/>\n",
       "    </defs>\n",
       "    <g clip-path=\"url(#p8674f2c5ad)\">\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"57.878125\" xlink:href=\"#m7dd7e52eab\" y=\"112.924219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"96.477484\" xlink:href=\"#m7dd7e52eab\" y=\"69.948869\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"120.357142\" xlink:href=\"#m7dd7e52eab\" y=\"54.42355\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"138.043261\" xlink:href=\"#m7dd7e52eab\" y=\"48.294977\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"152.011133\" xlink:href=\"#m7dd7e52eab\" y=\"46.027885\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"163.368487\" xlink:href=\"#m7dd7e52eab\" y=\"45.269751\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"172.726061\" xlink:href=\"#m7dd7e52eab\" y=\"45.04561\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"180.468975\" xlink:href=\"#m7dd7e52eab\" y=\"44.988536\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"186.866078\" xlink:href=\"#m7dd7e52eab\" y=\"44.976471\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"192.121325\" xlink:href=\"#m7dd7e52eab\" y=\"44.97447\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"196.40025\" xlink:href=\"#m7dd7e52eab\" y=\"44.974235\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"199.844228\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"202.578131\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"204.714223\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"206.353865\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"207.588039\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"208.497343\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.15187\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.611243\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.924922\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"210.132825\" xlink:href=\"#m7dd7e52eab\" y=\"44.974219\"/>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 42.620312 146.899219 \n",
       "L 42.620312 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 237.920313 146.899219 \n",
       "L 237.920313 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 42.620313 146.899219 \n",
       "L 237.920313 146.899219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 42.620313 10.999219 \n",
       "L 237.920313 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p8674f2c5ad\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"42.620312\" y=\"10.999219\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '..')\n",
    "\n",
    "import d2l\n",
    "import math\n",
    "import torch\n",
    "\n",
    "def rmsprop_2d(x1, x2, s1, s2):\n",
    "    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6\n",
    "    s1 = gamma * s1 + (1 - gamma) * g1 ** 2\n",
    "    s2 = gamma * s2 + (1 - gamma) * g2 ** 2\n",
    "    x1 -= eta / math.sqrt(s1 + eps) * g1\n",
    "    x2 -= eta / math.sqrt(s2 + eps) * g2\n",
    "    return x1, x2, s1, s2\n",
    "\n",
    "def f_2d(x1, x2):\n",
    "    return 0.1 * x1 ** 2 + 2 * x2 ** 2\n",
    "eta, gamma = 0.4, 0.9\n",
    "d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8.6.2 Implementation from Scratch\n",
    "\n",
    "Next, we implement RMSProp with the formula in the algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_rmsprop_states():\n",
    "    s_w = torch.zeros((features.shape[1], 1))\n",
    "    s_b = torch.zeros(1)\n",
    "    return (s_w, s_b)\n",
    "\n",
    "def rmsprop(params, states, hyperparams):\n",
    "    gamma, eps = hyperparams['gamma'], 1e-6\n",
    "    for p, s in zip(params, states):\n",
    "        s[:] = gamma * s + (1 - gamma) * p.grad**2\n",
    "        p[:] -= hyperparams['lr'] * p.grad / (s + eps).sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We set the initial learning rate to 0.01 and the hyperparameter $ \\gamma $ to 0.9. Now, the variable $s_t$ can be treated\n",
    "as the weighted average of the square term $g_t$ ⊙ $g_t$ from the last 1/(1 − 0.9) = 10 time steps."
   ]
  },
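  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the figure of 10 time steps comes from, the recursion for $\\mathbf{s}_t$ can be unrolled. The following is a sketch that assumes the state starts at $\\mathbf{s}_0 = 0$, as in `init_rmsprop_states`, with $*$ denoting the elementwise product:\n",
    "\n",
    "$$ \\begin{aligned} \\mathbf{s}_t &= (1 - \\gamma) \\sum_{i=0}^{t-1} \\gamma^{i} \\, \\mathbf{g}_{t-i} * \\mathbf{g}_{t-i}, \\qquad (1 - \\gamma) \\sum_{i=0}^{t-1} \\gamma^{i} = 1 - \\gamma^{t} \\end{aligned} $$\n",
    "\n",
    "The weight on a squared gradient from $n = 1/(1 - \\gamma)$ steps back is $\\gamma^{n} = (1 - 1/n)^{n} \\approx 1/e$ times the weight on the newest one; for $\\gamma = 0.9$ this is $0.9^{10} \\approx 0.35$. Older terms are commonly neglected on this basis, which is why $\\mathbf{s}_t$ is described, approximately, as an average over the last 10 time steps."
   ]
  },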
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8.6.3 Concise Implementation\n",
    "\n",
    "From the *Trainer* instance of the algorithm named rmsprop, we can implement the **RMSProp** algorithm\n",
    "with Gluon to train models. Note that the hyperparameter $ \\gamma $ is assigned by *gamma1*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.242, 0.012 sec/epoch\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"184.455469pt\" version=\"1.1\" viewBox=\"0 0 266.957813 184.455469\" width=\"266.957813pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;white-space:pre;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 184.455469 \n",
       "L 266.957813 184.455469 \n",
       "L 266.957813 0 \n",
       "L -0 0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 56.50625 146.899219 \n",
       "L 251.80625 146.899219 \n",
       "L 251.80625 10.999219 \n",
       "L 56.50625 10.999219 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 146.899219 \n",
       "L 56.50625 10.999219 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"mc9a12aad56\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#mc9a12aad56\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "       <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-46\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(48.554688 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 105.33125 146.899219 \n",
       "L 105.33125 10.999219 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"105.33125\" xlink:href=\"#mc9a12aad56\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 0.5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-53\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(97.379688 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 154.15625 146.899219 \n",
       "L 154.15625 10.999219 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"154.15625\" xlink:href=\"#mc9a12aad56\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 1.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(146.204688 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 202.98125 146.899219 \n",
       "L 202.98125 10.999219 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"202.98125\" xlink:href=\"#mc9a12aad56\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 1.5 -->\n",
       "      <g transform=\"translate(195.029688 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 251.80625 146.899219 \n",
       "L 251.80625 10.999219 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"251.80625\" xlink:href=\"#mc9a12aad56\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 2.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(243.854688 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <defs>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-101\"/>\n",
       "      <path d=\"M 18.109375 8.203125 \n",
       "L 18.109375 -20.796875 \n",
       "L 9.078125 -20.796875 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.390625 \n",
       "Q 20.953125 51.265625 25.265625 53.625 \n",
       "Q 29.59375 56 35.59375 56 \n",
       "Q 45.5625 56 51.78125 48.09375 \n",
       "Q 58.015625 40.1875 58.015625 27.296875 \n",
       "Q 58.015625 14.40625 51.78125 6.484375 \n",
       "Q 45.5625 -1.421875 35.59375 -1.421875 \n",
       "Q 29.59375 -1.421875 25.265625 0.953125 \n",
       "Q 20.953125 3.328125 18.109375 8.203125 \n",
       "z\n",
       "M 48.6875 27.296875 \n",
       "Q 48.6875 37.203125 44.609375 42.84375 \n",
       "Q 40.53125 48.484375 33.40625 48.484375 \n",
       "Q 26.265625 48.484375 22.1875 42.84375 \n",
       "Q 18.109375 37.203125 18.109375 27.296875 \n",
       "Q 18.109375 17.390625 22.1875 11.75 \n",
       "Q 26.265625 6.109375 33.40625 6.109375 \n",
       "Q 40.53125 6.109375 44.609375 11.75 \n",
       "Q 48.6875 17.390625 48.6875 27.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-112\"/>\n",
       "      <path d=\"M 30.609375 48.390625 \n",
       "Q 23.390625 48.390625 19.1875 42.75 \n",
       "Q 14.984375 37.109375 14.984375 27.296875 \n",
       "Q 14.984375 17.484375 19.15625 11.84375 \n",
       "Q 23.34375 6.203125 30.609375 6.203125 \n",
       "Q 37.796875 6.203125 41.984375 11.859375 \n",
       "Q 46.1875 17.53125 46.1875 27.296875 \n",
       "Q 46.1875 37.015625 41.984375 42.703125 \n",
       "Q 37.796875 48.390625 30.609375 48.390625 \n",
       "z\n",
       "M 30.609375 56 \n",
       "Q 42.328125 56 49.015625 48.375 \n",
       "Q 55.71875 40.765625 55.71875 27.296875 \n",
       "Q 55.71875 13.875 49.015625 6.21875 \n",
       "Q 42.328125 -1.421875 30.609375 -1.421875 \n",
       "Q 18.84375 -1.421875 12.171875 6.21875 \n",
       "Q 5.515625 13.875 5.515625 27.296875 \n",
       "Q 5.515625 40.765625 12.171875 48.375 \n",
       "Q 18.84375 56 30.609375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-111\"/>\n",
       "      <path d=\"M 48.78125 52.59375 \n",
       "L 48.78125 44.1875 \n",
       "Q 44.96875 46.296875 41.140625 47.34375 \n",
       "Q 37.3125 48.390625 33.40625 48.390625 \n",
       "Q 24.65625 48.390625 19.8125 42.84375 \n",
       "Q 14.984375 37.3125 14.984375 27.296875 \n",
       "Q 14.984375 17.28125 19.8125 11.734375 \n",
       "Q 24.65625 6.203125 33.40625 6.203125 \n",
       "Q 37.3125 6.203125 41.140625 7.25 \n",
       "Q 44.96875 8.296875 48.78125 10.40625 \n",
       "L 48.78125 2.09375 \n",
       "Q 45.015625 0.34375 40.984375 -0.53125 \n",
       "Q 36.96875 -1.421875 32.421875 -1.421875 \n",
       "Q 20.0625 -1.421875 12.78125 6.34375 \n",
       "Q 5.515625 14.109375 5.515625 27.296875 \n",
       "Q 5.515625 40.671875 12.859375 48.328125 \n",
       "Q 20.21875 56 33.015625 56 \n",
       "Q 37.15625 56 41.109375 55.140625 \n",
       "Q 45.0625 54.296875 48.78125 52.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-99\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 75.984375 \n",
       "L 18.109375 75.984375 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-104\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(138.928125 175.175781)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"61.523438\" xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-111\"/>\n",
       "      <use x=\"186.181641\" xlink:href=\"#DejaVuSans-99\"/>\n",
       "      <use x=\"241.162109\" xlink:href=\"#DejaVuSans-104\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 141.672296 \n",
       "L 251.80625 141.672296 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m4b06409394\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#m4b06409394\" y=\"141.672296\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.225 -->\n",
       "      <g transform=\"translate(20.878125 145.471514)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"222.65625\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 115.53768 \n",
       "L 251.80625 115.53768 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#m4b06409394\" y=\"115.53768\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.250 -->\n",
       "      <g transform=\"translate(20.878125 119.336899)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"222.65625\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 89.403065 \n",
       "L 251.80625 89.403065 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#m4b06409394\" y=\"89.403065\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.275 -->\n",
       "      <defs>\n",
       "       <path d=\"M 8.203125 72.90625 \n",
       "L 55.078125 72.90625 \n",
       "L 55.078125 68.703125 \n",
       "L 28.609375 0 \n",
       "L 18.3125 0 \n",
       "L 43.21875 64.59375 \n",
       "L 8.203125 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-55\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 93.202284)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-55\"/>\n",
       "       <use x=\"222.65625\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_17\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 63.26845 \n",
       "L 251.80625 63.26845 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_18\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#m4b06409394\" y=\"63.26845\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_10\">\n",
       "      <!-- 0.300 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-51\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 67.067668)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-51\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"222.65625\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_19\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 37.133834 \n",
       "L 251.80625 37.133834 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_20\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#m4b06409394\" y=\"37.133834\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_11\">\n",
       "      <!-- 0.325 -->\n",
       "      <g transform=\"translate(20.878125 40.933053)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-51\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"222.65625\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_6\">\n",
       "     <g id=\"line2d_21\">\n",
       "      <path clip-path=\"url(#pbe888d7fac)\" d=\"M 56.50625 10.999219 \n",
       "L 251.80625 10.999219 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_22\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"56.50625\" xlink:href=\"#m4b06409394\" y=\"10.999219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_12\">\n",
       "      <!-- 0.350 -->\n",
       "      <g transform=\"translate(20.878125 14.798437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-51\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"222.65625\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_13\">\n",
       "     <!-- loss -->\n",
       "     <defs>\n",
       "      <path d=\"M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-108\"/>\n",
       "      <path d=\"M 44.28125 53.078125 \n",
       "L 44.28125 44.578125 \n",
       "Q 40.484375 46.53125 36.375 47.5 \n",
       "Q 32.28125 48.484375 27.875 48.484375 \n",
       "Q 21.1875 48.484375 17.84375 46.4375 \n",
       "Q 14.5 44.390625 14.5 40.28125 \n",
       "Q 14.5 37.15625 16.890625 35.375 \n",
       "Q 19.28125 33.59375 26.515625 31.984375 \n",
       "L 29.59375 31.296875 \n",
       "Q 39.15625 29.25 43.1875 25.515625 \n",
       "Q 47.21875 21.78125 47.21875 15.09375 \n",
       "Q 47.21875 7.46875 41.1875 3.015625 \n",
       "Q 35.15625 -1.421875 24.609375 -1.421875 \n",
       "Q 20.21875 -1.421875 15.453125 -0.5625 \n",
       "Q 10.6875 0.296875 5.421875 2 \n",
       "L 5.421875 11.28125 \n",
       "Q 10.40625 8.6875 15.234375 7.390625 \n",
       "Q 20.0625 6.109375 24.8125 6.109375 \n",
       "Q 31.15625 6.109375 34.5625 8.28125 \n",
       "Q 37.984375 10.453125 37.984375 14.40625 \n",
       "Q 37.984375 18.0625 35.515625 20.015625 \n",
       "Q 33.0625 21.96875 24.703125 23.78125 \n",
       "L 21.578125 24.515625 \n",
       "Q 13.234375 26.265625 9.515625 29.90625 \n",
       "Q 5.8125 33.546875 5.8125 39.890625 \n",
       "Q 5.8125 47.609375 11.28125 51.796875 \n",
       "Q 16.75 56 26.8125 56 \n",
       "Q 31.78125 56 36.171875 55.265625 \n",
       "Q 40.578125 54.546875 44.28125 53.078125 \n",
       "z\n",
       "\" id=\"DejaVuSans-115\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798438 88.607031)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-108\"/>\n",
       "      <use x=\"27.783203\" xlink:href=\"#DejaVuSans-111\"/>\n",
       "      <use x=\"88.964844\" xlink:href=\"#DejaVuSans-115\"/>\n",
       "      <use x=\"141.064453\" xlink:href=\"#DejaVuSans-115\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_23\">\n",
       "    <path clip-path=\"url(#pbe888d7fac)\" d=\"M 69.52625 10.628374 \n",
       "L 82.54625 62.316278 \n",
       "L 95.56625 81.972161 \n",
       "L 108.58625 105.988003 \n",
       "L 121.60625 111.525268 \n",
       "L 134.62625 118.004895 \n",
       "L 147.64625 119.792632 \n",
       "L 160.66625 116.890894 \n",
       "L 173.68625 122.420782 \n",
       "L 186.70625 121.036899 \n",
       "L 199.72625 120.498906 \n",
       "L 212.74625 122.134688 \n",
       "L 225.76625 121.890286 \n",
       "L 238.78625 121.111767 \n",
       "L 251.80625 123.508055 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 56.50625 146.899219 \n",
       "L 56.50625 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 251.80625 146.899219 \n",
       "L 251.80625 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 56.50625 146.899219 \n",
       "L 251.80625 146.899219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 56.50625 10.999219 \n",
       "L 251.80625 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pbe888d7fac\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"56.50625\" y=\"10.999219\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data_iter, feature_dim = d2l.get_data_ch10(batch_size=10)\n",
    "\n",
    "d2l.train_ch10(torch.optim.RMSprop, {'lr': 0.01, 'gamma': 0.9}, data_iter, feature_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "* The difference between RMSProp and Adagrad is that RMSProp uses an EWMA on the squares of elements in the mini-batch stochastic gradient to adjust the learning rate.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "* What happens to the experimental results if we set the value of $γ$ to 1? Why?\n",
    "\n",
    "* Try using other combinations of initial learning rates and γ hyperparameters and observe and ana-lyze the experimental results.\n",
    "\n",
    "## Reference\n",
    "\n",
    "[1] Tieleman, T., & Hinton, G. (2012). Lecture 6.5-rmsprop: Divide the gradient by a running average of\n",
    "its recent magnitude. COURSERA: Neural networks for machine learning, 4(2), 26-31."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
